{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "rQsYkXeIkL6d"
      },
      "outputs": [],
      "source": [
        "#@title ##### License\n",
        "# Copyright 2018 The GraphNets Authors. All Rights Reserved.\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#    http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or  implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License.\n",
        "# ============================================================================"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "bXBusmrp1vaL"
      },
      "source": [
        "# Find the shortest path in a graph\n",
        "This notebook and the accompanying code demonstrates how to use the Graph Nets library to learn to predict the shortest path between two nodes in graph.\n",
        "\n",
        "The network is trained to label the nodes and edges of the shortest path, given the start and end nodes.\n",
        "\n",
        "After training, the network's prediction ability is illustrated by comparing its output to the true shortest path. Then the network's ability to generalise is tested, by using it to predict the shortest path in similar but larger graphs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "FlBiBDZjK-Tl"
      },
      "outputs": [],
      "source": [
        "#@title ### Install the Graph Nets library on this Colaboratory runtime  { form-width: \"60%\", run: \"auto\"}\n",
        "#@markdown \u003cbr\u003e1. Connect to a local or hosted Colaboratory runtime by clicking the **Connect** button at the top-right.\u003cbr\u003e2. Choose \"Yes\" below to install the Graph Nets library on the runtime machine with the correct dependencies. Note, this works both with local and hosted Colaboratory runtimes.\n",
        "\n",
        "install_graph_nets_library = \"No\"  #@param [\"Yes\", \"No\"]\n",
        "\n",
        "if install_graph_nets_library.lower() == \"yes\":\n",
        "  print(\"Installing Graph Nets library and dependencies:\")\n",
        "  print(\"Output message from command:\\n\")\n",
        "  !pip install graph_nets \"dm-sonnet\u003c2\" \"tensorflow_probability\u003c0.9\"\n",
        "else:\n",
        "  print(\"Skipping installation of Graph Nets library\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "31YqFsfHGab3"
      },
      "source": [
        "### Install dependencies locally\n",
        "\n",
        "If you are running this notebook locally (i.e., not through Colaboratory), you will also need to install a few more dependencies. Run the following on the command line to install the graph networks library, as well as a few other dependencies:\n",
        "\n",
        "```\n",
        "pip install graph_nets matplotlib scipy \"tensorflow\u003e=1.15,\u003c2\" \"dm-sonnet\u003c2\" \"tensorflow_probability\u003c0.9\"\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "ntNJc6x_F4u5"
      },
      "source": [
        "# Code"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "tjd3-8PJdK2m"
      },
      "outputs": [],
      "source": [
        "#@title Imports  { form-width: \"30%\" }\n",
        "%tensorflow_version 1.x  # For Google Colab only.\n",
        "from __future__ import absolute_import\n",
        "from __future__ import division\n",
        "from __future__ import print_function\n",
        "\n",
        "import collections\n",
        "import itertools\n",
        "import time\n",
        "\n",
        "from graph_nets import graphs\n",
        "from graph_nets import utils_np\n",
        "from graph_nets import utils_tf\n",
        "from graph_nets.demos import models\n",
        "import matplotlib.pyplot as plt\n",
        "import networkx as nx\n",
        "import numpy as np\n",
        "from scipy import spatial\n",
        "import tensorflow as tf\n",
        "\n",
        "\n",
        "SEED = 1\n",
        "np.random.seed(SEED)\n",
        "tf.set_random_seed(SEED)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "TrGithqWUML7"
      },
      "outputs": [],
      "source": [
        "#@title Helper functions  { form-width: \"30%\" }\n",
        "\n",
        "# pylint: disable=redefined-outer-name\n",
        "\n",
        "DISTANCE_WEIGHT_NAME = \"distance\"  # The name for the distance edge attribute.\n",
        "\n",
        "\n",
        "def pairwise(iterable):\n",
        "  \"\"\"s -\u003e (s0,s1), (s1,s2), (s2, s3), ...\"\"\"\n",
        "  a, b = itertools.tee(iterable)\n",
        "  next(b, None)\n",
        "  return zip(a, b)\n",
        "\n",
        "\n",
        "def set_diff(seq0, seq1):\n",
        "  \"\"\"Return the set difference between 2 sequences as a list.\"\"\"\n",
        "  return list(set(seq0) - set(seq1))\n",
        "\n",
        "\n",
        "def to_one_hot(indices, max_value, axis=-1):\n",
        "  one_hot = np.eye(max_value)[indices]\n",
        "  if axis not in (-1, one_hot.ndim):\n",
        "    one_hot = np.moveaxis(one_hot, -1, axis)\n",
        "  return one_hot\n",
        "\n",
        "\n",
        "def get_node_dict(graph, attr):\n",
        "  \"\"\"Return a `dict` of node:attribute pairs from a graph.\"\"\"\n",
        "  return {k: v[attr] for k, v in graph.nodes.items()}\n",
        "\n",
        "\n",
        "def generate_graph(rand,\n",
        "                   num_nodes_min_max,\n",
        "                   dimensions=2,\n",
        "                   theta=1000.0,\n",
        "                   rate=1.0):\n",
        "  \"\"\"Creates a connected graph.\n",
        "\n",
        "  The graphs are geographic threshold graphs, but with added edges via a\n",
        "  minimum spanning tree algorithm, to ensure all nodes are connected.\n",
        "\n",
        "  Args:\n",
        "    rand: A random seed for the graph generator. Default= None.\n",
        "    num_nodes_min_max: A sequence [lower, upper) number of nodes per graph.\n",
        "    dimensions: (optional) An `int` number of dimensions for the positions.\n",
        "      Default= 2.\n",
        "    theta: (optional) A `float` threshold parameters for the geographic\n",
        "      threshold graph's threshold. Large values (1000+) make mostly trees. Try\n",
        "      20-60 for good non-trees. Default=1000.0.\n",
        "    rate: (optional) A rate parameter for the node weight exponential sampling\n",
        "      distribution. Default= 1.0.\n",
        "\n",
        "  Returns:\n",
        "    The graph.\n",
        "  \"\"\"\n",
        "  # Sample num_nodes.\n",
        "  num_nodes = rand.randint(*num_nodes_min_max)\n",
        "\n",
        "  # Create geographic threshold graph.\n",
        "  pos_array = rand.uniform(size=(num_nodes, dimensions))\n",
        "  pos = dict(enumerate(pos_array))\n",
        "  weight = dict(enumerate(rand.exponential(rate, size=num_nodes)))\n",
        "  geo_graph = nx.geographical_threshold_graph(\n",
        "      num_nodes, theta, pos=pos, weight=weight)\n",
        "\n",
        "  # Create minimum spanning tree across geo_graph's nodes.\n",
        "  distances = spatial.distance.squareform(spatial.distance.pdist(pos_array))\n",
        "  i_, j_ = np.meshgrid(range(num_nodes), range(num_nodes), indexing=\"ij\")\n",
        "  weighted_edges = list(zip(i_.ravel(), j_.ravel(), distances.ravel()))\n",
        "  mst_graph = nx.Graph()\n",
        "  mst_graph.add_weighted_edges_from(weighted_edges, weight=DISTANCE_WEIGHT_NAME)\n",
        "  mst_graph = nx.minimum_spanning_tree(mst_graph, weight=DISTANCE_WEIGHT_NAME)\n",
        "  # Put geo_graph's node attributes into the mst_graph.\n",
        "  for i in mst_graph.nodes():\n",
        "    mst_graph.nodes[i].update(geo_graph.nodes[i])\n",
        "\n",
        "  # Compose the graphs.\n",
        "  combined_graph = nx.compose_all((mst_graph, geo_graph.copy()))\n",
        "  # Put all distance weights into edge attributes.\n",
        "  for i, j in combined_graph.edges():\n",
        "    combined_graph.get_edge_data(i, j).setdefault(DISTANCE_WEIGHT_NAME,\n",
        "                                                  distances[i, j])\n",
        "  return combined_graph, mst_graph, geo_graph\n",
        "\n",
        "\n",
        "def add_shortest_path(rand, graph, min_length=1):\n",
        "  \"\"\"Samples a shortest path from A to B and adds attributes to indicate it.\n",
        "\n",
        "  Args:\n",
        "    rand: A random seed for the graph generator. Default= None.\n",
        "    graph: A `nx.Graph`.\n",
        "    min_length: (optional) An `int` minimum number of edges in the shortest\n",
        "      path. Default= 1.\n",
        "\n",
        "  Returns:\n",
        "    The `nx.DiGraph` with the shortest path added.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: All shortest paths are below the minimum length\n",
        "  \"\"\"\n",
        "  # Map from node pairs to the length of their shortest path.\n",
        "  pair_to_length_dict = {}\n",
        "  try:\n",
        "    # This is for compatibility with older networkx.\n",
        "    lengths = nx.all_pairs_shortest_path_length(graph).items()\n",
        "  except AttributeError:\n",
        "    # This is for compatibility with newer networkx.\n",
        "    lengths = list(nx.all_pairs_shortest_path_length(graph))\n",
        "  for x, yy in lengths:\n",
        "    for y, l in yy.items():\n",
        "      if l \u003e= min_length:\n",
        "        pair_to_length_dict[x, y] = l\n",
        "  if max(pair_to_length_dict.values()) \u003c min_length:\n",
        "    raise ValueError(\"All shortest paths are below the minimum length\")\n",
        "  # The node pairs which exceed the minimum length.\n",
        "  node_pairs = list(pair_to_length_dict)\n",
        "\n",
        "  # Computes probabilities per pair, to enforce uniform sampling of each\n",
        "  # shortest path lengths.\n",
        "  # The counts of pairs per length.\n",
        "  counts = collections.Counter(pair_to_length_dict.values())\n",
        "  prob_per_length = 1.0 / len(counts)\n",
        "  probabilities = [\n",
        "      prob_per_length / counts[pair_to_length_dict[x]] for x in node_pairs\n",
        "  ]\n",
        "\n",
        "  # Choose the start and end points.\n",
        "  i = rand.choice(len(node_pairs), p=probabilities)\n",
        "  start, end = node_pairs[i]\n",
        "  path = nx.shortest_path(\n",
        "      graph, source=start, target=end, weight=DISTANCE_WEIGHT_NAME)\n",
        "\n",
        "  # Creates a directed graph, to store the directed path from start to end.\n",
        "  digraph = graph.to_directed()\n",
        "\n",
        "  # Add the \"start\", \"end\", and \"solution\" attributes to the nodes and edges.\n",
        "  digraph.add_node(start, start=True)\n",
        "  digraph.add_node(end, end=True)\n",
        "  digraph.add_nodes_from(set_diff(digraph.nodes(), [start]), start=False)\n",
        "  digraph.add_nodes_from(set_diff(digraph.nodes(), [end]), end=False)\n",
        "  digraph.add_nodes_from(set_diff(digraph.nodes(), path), solution=False)\n",
        "  digraph.add_nodes_from(path, solution=True)\n",
        "  path_edges = list(pairwise(path))\n",
        "  digraph.add_edges_from(set_diff(digraph.edges(), path_edges), solution=False)\n",
        "  digraph.add_edges_from(path_edges, solution=True)\n",
        "\n",
        "  return digraph\n",
        "\n",
        "\n",
        "def graph_to_input_target(graph):\n",
        "  \"\"\"Returns 2 graphs with input and target feature vectors for training.\n",
        "\n",
        "  Args:\n",
        "    graph: An `nx.DiGraph` instance.\n",
        "\n",
        "  Returns:\n",
        "    The input `nx.DiGraph` instance.\n",
        "    The target `nx.DiGraph` instance.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: unknown node type\n",
        "  \"\"\"\n",
        "\n",
        "  def create_feature(attr, fields):\n",
        "    return np.hstack([np.array(attr[field], dtype=float) for field in fields])\n",
        "\n",
        "  input_node_fields = (\"pos\", \"weight\", \"start\", \"end\")\n",
        "  input_edge_fields = (\"distance\",)\n",
        "  target_node_fields = (\"solution\",)\n",
        "  target_edge_fields = (\"solution\",)\n",
        "\n",
        "  input_graph = graph.copy()\n",
        "  target_graph = graph.copy()\n",
        "\n",
        "  solution_length = 0\n",
        "  for node_index, node_feature in graph.nodes(data=True):\n",
        "    input_graph.add_node(\n",
        "        node_index, features=create_feature(node_feature, input_node_fields))\n",
        "    target_node = to_one_hot(\n",
        "        create_feature(node_feature, target_node_fields).astype(int), 2)[0]\n",
        "    target_graph.add_node(node_index, features=target_node)\n",
        "    solution_length += int(node_feature[\"solution\"])\n",
        "  solution_length /= graph.number_of_nodes()\n",
        "\n",
        "  for receiver, sender, features in graph.edges(data=True):\n",
        "    input_graph.add_edge(\n",
        "        sender, receiver, features=create_feature(features, input_edge_fields))\n",
        "    target_edge = to_one_hot(\n",
        "        create_feature(features, target_edge_fields).astype(int), 2)[0]\n",
        "    target_graph.add_edge(sender, receiver, features=target_edge)\n",
        "\n",
        "  input_graph.graph[\"features\"] = np.array([0.0])\n",
        "  target_graph.graph[\"features\"] = np.array([solution_length], dtype=float)\n",
        "\n",
        "  return input_graph, target_graph\n",
        "\n",
        "\n",
        "def generate_networkx_graphs(rand, num_examples, num_nodes_min_max, theta):\n",
        "  \"\"\"Generate graphs for training.\n",
        "\n",
        "  Args:\n",
        "    rand: A random seed (np.RandomState instance).\n",
        "    num_examples: Total number of graphs to generate.\n",
        "    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per\n",
        "      graph. The number of nodes for a graph is uniformly sampled within this\n",
        "      range.\n",
        "    theta: (optional) A `float` threshold parameters for the geographic\n",
        "      threshold graph's threshold. Default= the number of nodes.\n",
        "\n",
        "  Returns:\n",
        "    input_graphs: The list of input graphs.\n",
        "    target_graphs: The list of output graphs.\n",
        "    graphs: The list of generated graphs.\n",
        "  \"\"\"\n",
        "  input_graphs = []\n",
        "  target_graphs = []\n",
        "  graphs = []\n",
        "  for _ in range(num_examples):\n",
        "    graph = generate_graph(rand, num_nodes_min_max, theta=theta)[0]\n",
        "    graph = add_shortest_path(rand, graph)\n",
        "    input_graph, target_graph = graph_to_input_target(graph)\n",
        "    input_graphs.append(input_graph)\n",
        "    target_graphs.append(target_graph)\n",
        "    graphs.append(graph)\n",
        "  return input_graphs, target_graphs, graphs\n",
        "\n",
        "\n",
        "def create_placeholders(rand, batch_size, num_nodes_min_max, theta):\n",
        "  \"\"\"Creates placeholders for the model training and evaluation.\n",
        "\n",
        "  Args:\n",
        "    rand: A random seed (np.RandomState instance).\n",
        "    batch_size: Total number of graphs per batch.\n",
        "    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per\n",
        "      graph. The number of nodes for a graph is uniformly sampled within this\n",
        "      range.\n",
        "    theta: A `float` threshold parameters for the geographic threshold graph's\n",
        "      threshold. Default= the number of nodes.\n",
        "\n",
        "  Returns:\n",
        "    input_ph: The input graph's placeholders, as a graph namedtuple.\n",
        "    target_ph: The target graph's placeholders, as a graph namedtuple.\n",
        "  \"\"\"\n",
        "  # Create some example data for inspecting the vector sizes.\n",
        "  input_graphs, target_graphs, _ = generate_networkx_graphs(\n",
        "      rand, batch_size, num_nodes_min_max, theta)\n",
        "  input_ph = utils_tf.placeholders_from_networkxs(input_graphs)\n",
        "  target_ph = utils_tf.placeholders_from_networkxs(target_graphs)\n",
        "  return input_ph, target_ph\n",
        "\n",
        "\n",
        "def create_feed_dict(rand, batch_size, num_nodes_min_max, theta, input_ph,\n",
        "                     target_ph):\n",
        "  \"\"\"Creates placeholders for the model training and evaluation.\n",
        "\n",
        "  Args:\n",
        "    rand: A random seed (np.RandomState instance).\n",
        "    batch_size: Total number of graphs per batch.\n",
        "    num_nodes_min_max: A 2-tuple with the [lower, upper) number of nodes per\n",
        "      graph. The number of nodes for a graph is uniformly sampled within this\n",
        "      range.\n",
        "    theta: A `float` threshold parameters for the geographic threshold graph's\n",
        "      threshold. Default= the number of nodes.\n",
        "    input_ph: The input graph's placeholders, as a graph namedtuple.\n",
        "    target_ph: The target graph's placeholders, as a graph namedtuple.\n",
        "\n",
        "  Returns:\n",
        "    feed_dict: The feed `dict` of input and target placeholders and data.\n",
        "    raw_graphs: The `dict` of raw networkx graphs.\n",
        "  \"\"\"\n",
        "  inputs, targets, raw_graphs = generate_networkx_graphs(\n",
        "      rand, batch_size, num_nodes_min_max, theta)\n",
        "  input_graphs = utils_np.networkxs_to_graphs_tuple(inputs)\n",
        "  target_graphs = utils_np.networkxs_to_graphs_tuple(targets)\n",
        "  feed_dict = {input_ph: input_graphs, target_ph: target_graphs}\n",
        "  return feed_dict, raw_graphs\n",
        "\n",
        "\n",
        "def compute_accuracy(target, output, use_nodes=True, use_edges=False):\n",
        "  \"\"\"Calculate model accuracy.\n",
        "\n",
        "  Returns the number of correctly predicted shortest path nodes and the number\n",
        "  of completely solved graphs (100% correct predictions).\n",
        "\n",
        "  Args:\n",
        "    target: A `graphs.GraphsTuple` that contains the target graph.\n",
        "    output: A `graphs.GraphsTuple` that contains the output graph.\n",
        "    use_nodes: A `bool` indicator of whether to compute node accuracy or not.\n",
        "    use_edges: A `bool` indicator of whether to compute edge accuracy or not.\n",
        "\n",
        "  Returns:\n",
        "    correct: A `float` fraction of correctly labeled nodes/edges.\n",
        "    solved: A `float` fraction of graphs that are completely correctly labeled.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: Nodes or edges (or both) must be used\n",
        "  \"\"\"\n",
        "  if not use_nodes and not use_edges:\n",
        "    raise ValueError(\"Nodes or edges (or both) must be used\")\n",
        "  tdds = utils_np.graphs_tuple_to_data_dicts(target)\n",
        "  odds = utils_np.graphs_tuple_to_data_dicts(output)\n",
        "  cs = []\n",
        "  ss = []\n",
        "  for td, od in zip(tdds, odds):\n",
        "    xn = np.argmax(td[\"nodes\"], axis=-1)\n",
        "    yn = np.argmax(od[\"nodes\"], axis=-1)\n",
        "    xe = np.argmax(td[\"edges\"], axis=-1)\n",
        "    ye = np.argmax(od[\"edges\"], axis=-1)\n",
        "    c = []\n",
        "    if use_nodes:\n",
        "      c.append(xn == yn)\n",
        "    if use_edges:\n",
        "      c.append(xe == ye)\n",
        "    c = np.concatenate(c, axis=0)\n",
        "    s = np.all(c)\n",
        "    cs.append(c)\n",
        "    ss.append(s)\n",
        "  correct = np.mean(np.concatenate(cs, axis=0))\n",
        "  solved = np.mean(np.stack(ss))\n",
        "  return correct, solved\n",
        "\n",
        "\n",
        "def create_loss_ops(target_op, output_ops):\n",
        "  loss_ops = [\n",
        "      tf.losses.softmax_cross_entropy(target_op.nodes, output_op.nodes) +\n",
        "      tf.losses.softmax_cross_entropy(target_op.edges, output_op.edges)\n",
        "      for output_op in output_ops\n",
        "  ]\n",
        "  return loss_ops\n",
        "\n",
        "\n",
        "def make_all_runnable_in_session(*args):\n",
        "  \"\"\"Lets an iterable of TF graphs be output from a session as NP graphs.\"\"\"\n",
        "  return [utils_tf.make_runnable_in_session(a) for a in args]\n",
        "\n",
        "\n",
        "class GraphPlotter(object):\n",
        "\n",
        "  def __init__(self, ax, graph, pos):\n",
        "    self._ax = ax\n",
        "    self._graph = graph\n",
        "    self._pos = pos\n",
        "    self._base_draw_kwargs = dict(G=self._graph, pos=self._pos, ax=self._ax)\n",
        "    self._solution_length = None\n",
        "    self._nodes = None\n",
        "    self._edges = None\n",
        "    self._start_nodes = None\n",
        "    self._end_nodes = None\n",
        "    self._solution_nodes = None\n",
        "    self._intermediate_solution_nodes = None\n",
        "    self._solution_edges = None\n",
        "    self._non_solution_nodes = None\n",
        "    self._non_solution_edges = None\n",
        "    self._ax.set_axis_off()\n",
        "\n",
        "  @property\n",
        "  def solution_length(self):\n",
        "    if self._solution_length is None:\n",
        "      self._solution_length = len(self._solution_edges)\n",
        "    return self._solution_length\n",
        "\n",
        "  @property\n",
        "  def nodes(self):\n",
        "    if self._nodes is None:\n",
        "      self._nodes = self._graph.nodes()\n",
        "    return self._nodes\n",
        "\n",
        "  @property\n",
        "  def edges(self):\n",
        "    if self._edges is None:\n",
        "      self._edges = self._graph.edges()\n",
        "    return self._edges\n",
        "\n",
        "  @property\n",
        "  def start_nodes(self):\n",
        "    if self._start_nodes is None:\n",
        "      self._start_nodes = [\n",
        "          n for n in self.nodes if self._graph.nodes[n].get(\"start\", False)\n",
        "      ]\n",
        "    return self._start_nodes\n",
        "\n",
        "  @property\n",
        "  def end_nodes(self):\n",
        "    if self._end_nodes is None:\n",
        "      self._end_nodes = [\n",
        "          n for n in self.nodes if self._graph.nodes[n].get(\"end\", False)\n",
        "      ]\n",
        "    return self._end_nodes\n",
        "\n",
        "  @property\n",
        "  def solution_nodes(self):\n",
        "    if self._solution_nodes is None:\n",
        "      self._solution_nodes = [\n",
        "          n for n in self.nodes if self._graph.nodes[n].get(\"solution\", False)\n",
        "      ]\n",
        "    return self._solution_nodes\n",
        "\n",
        "  @property\n",
        "  def intermediate_solution_nodes(self):\n",
        "    if self._intermediate_solution_nodes is None:\n",
        "      self._intermediate_solution_nodes = [\n",
        "          n for n in self.nodes\n",
        "          if self._graph.nodes[n].get(\"solution\", False) and\n",
        "          not self._graph.nodes[n].get(\"start\", False) and\n",
        "          not self._graph.nodes[n].get(\"end\", False)\n",
        "      ]\n",
        "    return self._intermediate_solution_nodes\n",
        "\n",
        "  @property\n",
        "  def solution_edges(self):\n",
        "    if self._solution_edges is None:\n",
        "      self._solution_edges = [\n",
        "          e for e in self.edges\n",
        "          if self._graph.get_edge_data(e[0], e[1]).get(\"solution\", False)\n",
        "      ]\n",
        "    return self._solution_edges\n",
        "\n",
        "  @property\n",
        "  def non_solution_nodes(self):\n",
        "    if self._non_solution_nodes is None:\n",
        "      self._non_solution_nodes = [\n",
        "          n for n in self.nodes\n",
        "          if not self._graph.nodes[n].get(\"solution\", False)\n",
        "      ]\n",
        "    return self._non_solution_nodes\n",
        "\n",
        "  @property\n",
        "  def non_solution_edges(self):\n",
        "    if self._non_solution_edges is None:\n",
        "      self._non_solution_edges = [\n",
        "          e for e in self.edges\n",
        "          if not self._graph.get_edge_data(e[0], e[1]).get(\"solution\", False)\n",
        "      ]\n",
        "    return self._non_solution_edges\n",
        "\n",
        "  def _make_draw_kwargs(self, **kwargs):\n",
        "    kwargs.update(self._base_draw_kwargs)\n",
        "    return kwargs\n",
        "\n",
        "  def _draw(self, draw_function, zorder=None, **kwargs):\n",
        "    draw_kwargs = self._make_draw_kwargs(**kwargs)\n",
        "    collection = draw_function(**draw_kwargs)\n",
        "    if collection is not None and zorder is not None:\n",
        "      try:\n",
        "        # This is for compatibility with older matplotlib.\n",
        "        collection.set_zorder(zorder)\n",
        "      except AttributeError:\n",
        "        # This is for compatibility with newer matplotlib.\n",
        "        collection[0].set_zorder(zorder)\n",
        "    return collection\n",
        "\n",
        "  def draw_nodes(self, **kwargs):\n",
        "    \"\"\"Useful kwargs: nodelist, node_size, node_color, linewidths.\"\"\"\n",
        "    if (\"node_color\" in kwargs and\n",
        "        isinstance(kwargs[\"node_color\"], collections.Sequence) and\n",
        "        len(kwargs[\"node_color\"]) in {3, 4} and\n",
        "        not isinstance(kwargs[\"node_color\"][0],\n",
        "                       (collections.Sequence, np.ndarray))):\n",
        "      num_nodes = len(kwargs.get(\"nodelist\", self.nodes))\n",
        "      kwargs[\"node_color\"] = np.tile(\n",
        "          np.array(kwargs[\"node_color\"])[None], [num_nodes, 1])\n",
        "    return self._draw(nx.draw_networkx_nodes, **kwargs)\n",
        "\n",
        "  def draw_edges(self, **kwargs):\n",
        "    \"\"\"Useful kwargs: edgelist, width.\"\"\"\n",
        "    return self._draw(nx.draw_networkx_edges, **kwargs)\n",
        "\n",
        "  def draw_graph(self,\n",
        "                 node_size=200,\n",
        "                 node_color=(0.4, 0.8, 0.4),\n",
        "                 node_linewidth=1.0,\n",
        "                 edge_width=1.0):\n",
        "    # Plot nodes.\n",
        "    self.draw_nodes(\n",
        "        nodelist=self.nodes,\n",
        "        node_size=node_size,\n",
        "        node_color=node_color,\n",
        "        linewidths=node_linewidth,\n",
        "        zorder=20)\n",
        "    # Plot edges.\n",
        "    self.draw_edges(edgelist=self.edges, width=edge_width, zorder=10)\n",
        "\n",
        "  def draw_graph_with_solution(self,\n",
        "                               node_size=200,\n",
        "                               node_color=(0.4, 0.8, 0.4),\n",
        "                               node_linewidth=1.0,\n",
        "                               edge_width=1.0,\n",
        "                               start_color=\"w\",\n",
        "                               end_color=\"k\",\n",
        "                               solution_node_linewidth=3.0,\n",
        "                               solution_edge_width=3.0):\n",
        "    node_border_color = (0.0, 0.0, 0.0, 1.0)\n",
        "    node_collections = {}\n",
        "    # Plot start nodes.\n",
        "    node_collections[\"start nodes\"] = self.draw_nodes(\n",
        "        nodelist=self.start_nodes,\n",
        "        node_size=node_size,\n",
        "        node_color=start_color,\n",
        "        linewidths=solution_node_linewidth,\n",
        "        edgecolors=node_border_color,\n",
        "        zorder=100)\n",
        "    # Plot end nodes.\n",
        "    node_collections[\"end nodes\"] = self.draw_nodes(\n",
        "        nodelist=self.end_nodes,\n",
        "        node_size=node_size,\n",
        "        node_color=end_color,\n",
        "        linewidths=solution_node_linewidth,\n",
        "        edgecolors=node_border_color,\n",
        "        zorder=90)\n",
        "    # Plot intermediate solution nodes.\n",
        "    if isinstance(node_color, dict):\n",
        "      c = [node_color[n] for n in self.intermediate_solution_nodes]\n",
        "    else:\n",
        "      c = node_color\n",
        "    node_collections[\"intermediate solution nodes\"] = self.draw_nodes(\n",
        "        nodelist=self.intermediate_solution_nodes,\n",
        "        node_size=node_size,\n",
        "        node_color=c,\n",
        "        linewidths=solution_node_linewidth,\n",
        "        edgecolors=node_border_color,\n",
        "        zorder=80)\n",
        "    # Plot solution edges.\n",
        "    node_collections[\"solution edges\"] = self.draw_edges(\n",
        "        edgelist=self.solution_edges, width=solution_edge_width, zorder=70)\n",
        "    # Plot non-solution nodes.\n",
        "    if isinstance(node_color, dict):\n",
        "      c = [node_color[n] for n in self.non_solution_nodes]\n",
        "    else:\n",
        "      c = node_color\n",
        "    node_collections[\"non-solution nodes\"] = self.draw_nodes(\n",
        "        nodelist=self.non_solution_nodes,\n",
        "        node_size=node_size,\n",
        "        node_color=c,\n",
        "        linewidths=node_linewidth,\n",
        "        edgecolors=node_border_color,\n",
        "        zorder=20)\n",
        "    # Plot non-solution edges.\n",
        "    node_collections[\"non-solution edges\"] = self.draw_edges(\n",
        "        edgelist=self.non_solution_edges, width=edge_width, zorder=10)\n",
        "    # Set title as solution length.\n",
        "    self._ax.set_title(\"Solution length: {}\".format(self.solution_length))\n",
        "    return node_collections\n",
        "\n",
        "\n",
        "# pylint: enable=redefined-outer-name"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "6oEV1OC3UQAc"
      },
      "outputs": [],
      "source": [
        "#@title Visualize example graphs  { form-width: \"30%\" }\n",
        "seed = 1  #@param{type: 'integer'}\n",
        "rand = np.random.RandomState(seed=seed)\n",
        "\n",
        "num_examples = 15  #@param{type: 'integer'}\n",
        "# Large values (1000+) make trees. Try 20-60 for good non-trees.\n",
        "theta = 20  #@param{type: 'integer'}\n",
        "num_nodes_min_max = (16, 17)\n",
        "\n",
        "input_graphs, target_graphs, graphs = generate_networkx_graphs(\n",
        "    rand, num_examples, num_nodes_min_max, theta)\n",
        "\n",
        "num = min(num_examples, 16)\n",
        "w = 3\n",
        "h = int(np.ceil(num / w))\n",
        "fig = plt.figure(40, figsize=(w * 4, h * 4))\n",
        "fig.clf()\n",
        "for j, graph in enumerate(graphs):\n",
        "  ax = fig.add_subplot(h, w, j + 1)\n",
        "  pos = get_node_dict(graph, \"pos\")\n",
        "  plotter = GraphPlotter(ax, graph, pos)\n",
        "  plotter.draw_graph_with_solution()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "cY09Bll0vuVj"
      },
      "outputs": [],
      "source": [
        "#@title Set up model training and evaluation  { form-width: \"30%\" }\n",
        "\n",
        "# The model we explore includes three components:\n",
        "# - An \"Encoder\" graph net, which independently encodes the edge, node, and\n",
        "#   global attributes (does not compute relations etc.).\n",
        "# - A \"Core\" graph net, which performs N rounds of processing (message-passing)\n",
        "#   steps. The input to the Core is the concatenation of the Encoder's output\n",
        "#   and the previous output of the Core (labeled \"Hidden(t)\" below, where \"t\" is\n",
        "#   the processing step).\n",
        "# - A \"Decoder\" graph net, which independently decodes the edge, node, and\n",
        "#   global attributes (does not compute relations etc.), on each\n",
        "#   message-passing step.\n",
        "#\n",
        "#                     Hidden(t)   Hidden(t+1)\n",
        "#                        |            ^\n",
        "#           *---------*  |  *------*  |  *---------*\n",
        "#           |         |  |  |      |  |  |         |\n",
        "# Input ---\u003e| Encoder |  *-\u003e| Core |--*-\u003e| Decoder |---\u003e Output(t)\n",
        "#           |         |----\u003e|      |     |         |\n",
        "#           *---------*     *------*     *---------*\n",
        "#\n",
        "# The model is trained by supervised learning. Input graphs are procedurally\n",
        "# generated, and output graphs have the same structure with the nodes and edges\n",
        "# of the shortest path labeled (using 2-element 1-hot vectors). We could have\n",
        "# predicted the shortest path only by labeling either the nodes or edges, and\n",
        "# that does work, but we decided to predict both to demonstrate the flexibility\n",
        "# of graph nets' outputs.\n",
        "#\n",
        "# The training loss is computed on the output of each processing step. The\n",
        "# reason for this is to encourage the model to try to solve the problem in as\n",
        "# few steps as possible. It also helps make the output of intermediate steps\n",
        "# more interpretable.\n",
        "#\n",
        "# There's no need for a separate evaluate dataset because the inputs are\n",
        "# never repeated, so the training loss is the measure of performance on graphs\n",
        "# from the input distribution.\n",
        "#\n",
        "# We also evaluate how well the models generalize to graphs which are up to\n",
        "# twice as large as those on which it was trained. The loss is computed only\n",
        "# on the final processing step.\n",
        "#\n",
        "# Variables with the suffix _tr are training parameters, and variables with the\n",
        "# suffix _ge are test/generalization parameters.\n",
        "#\n",
        "# After around 2000-5000 training iterations the model reaches near-perfect\n",
        "# performance on graphs with between 8-16 nodes.\n",
        "\n",
        "tf.reset_default_graph()\n",
        "\n",
        "seed = 2\n",
        "rand = np.random.RandomState(seed=seed)\n",
        "\n",
        "# Model parameters.\n",
        "# Number of processing (message-passing) steps.\n",
        "num_processing_steps_tr = 10\n",
        "num_processing_steps_ge = 10\n",
        "\n",
        "# Data / training parameters.\n",
        "num_training_iterations = 10000\n",
        "theta = 20  # Large values (1000+) make trees. Try 20-60 for good non-trees.\n",
        "batch_size_tr = 32\n",
        "batch_size_ge = 100\n",
        "# Number of nodes per graph sampled uniformly from this range.\n",
        "num_nodes_min_max_tr = (8, 17)\n",
        "num_nodes_min_max_ge = (16, 33)\n",
        "\n",
        "# Data.\n",
        "# Input and target placeholders.\n",
        "input_ph, target_ph = create_placeholders(rand, batch_size_tr,\n",
        "                                          num_nodes_min_max_tr, theta)\n",
        "\n",
        "# Connect the data to the model.\n",
        "# Instantiate the model.\n",
        "model = models.EncodeProcessDecode(edge_output_size=2, node_output_size=2)\n",
        "# A list of outputs, one per processing step.\n",
        "output_ops_tr = model(input_ph, num_processing_steps_tr)\n",
        "output_ops_ge = model(input_ph, num_processing_steps_ge)\n",
        "\n",
        "# Training loss.\n",
        "loss_ops_tr = create_loss_ops(target_ph, output_ops_tr)\n",
        "# Loss across processing steps.\n",
        "loss_op_tr = sum(loss_ops_tr) / num_processing_steps_tr\n",
        "# Test/generalization loss.\n",
        "loss_ops_ge = create_loss_ops(target_ph, output_ops_ge)\n",
        "loss_op_ge = loss_ops_ge[-1]  # Loss from final processing step.\n",
        "\n",
        "# Optimizer.\n",
        "learning_rate = 1e-3\n",
        "optimizer = tf.train.AdamOptimizer(learning_rate)\n",
        "step_op = optimizer.minimize(loss_op_tr)\n",
        "\n",
        "# Lets an iterable of TF graphs be output from a session as NP graphs.\n",
        "input_ph, target_ph = make_all_runnable_in_session(input_ph, target_ph)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "WoVdyUTjvzWb"
      },
      "outputs": [],
      "source": [
        "#@title Reset session  { form-width: \"30%\" }\n",
        "\n",
        "# This cell resets the Tensorflow session, but keeps the same computational\n",
        "# graph.\n",
        "\n",
        "try:\n",
        "  sess.close()\n",
        "except NameError:\n",
        "  pass\n",
        "sess = tf.Session()\n",
        "sess.run(tf.global_variables_initializer())\n",
        "\n",
        "last_iteration = 0\n",
        "logged_iterations = []\n",
        "losses_tr = []\n",
        "corrects_tr = []\n",
        "solveds_tr = []\n",
        "losses_ge = []\n",
        "corrects_ge = []\n",
        "solveds_ge = []"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "wWSqSYyQv0Ur"
      },
      "outputs": [],
      "source": [
        "#@title Run training  { form-width: \"30%\" }\n",
        "\n",
        "# You can interrupt this cell's training loop at any time, and visualize the\n",
        "# intermediate results by running the next cell (below). You can then resume\n",
        "# training by simply executing this cell again.\n",
        "\n",
        "# How much time between logging and printing the current results.\n",
        "log_every_seconds = 20\n",
        "\n",
        "print(\"# (iteration number), T (elapsed seconds), \"\n",
        "      \"Ltr (training loss), Lge (test/generalization loss), \"\n",
        "      \"Ctr (training fraction nodes/edges labeled correctly), \"\n",
        "      \"Str (training fraction examples solved correctly), \"\n",
        "      \"Cge (test/generalization fraction nodes/edges labeled correctly), \"\n",
        "      \"Sge (test/generalization fraction examples solved correctly)\")\n",
        "\n",
        "start_time = time.time()\n",
        "last_log_time = start_time\n",
        "for iteration in range(last_iteration, num_training_iterations):\n",
        "  last_iteration = iteration\n",
        "  feed_dict, _ = create_feed_dict(rand, batch_size_tr, num_nodes_min_max_tr,\n",
        "                                  theta, input_ph, target_ph)\n",
        "  train_values = sess.run({\n",
        "      \"step\": step_op,\n",
        "      \"target\": target_ph,\n",
        "      \"loss\": loss_op_tr,\n",
        "      \"outputs\": output_ops_tr\n",
        "  },\n",
        "                          feed_dict=feed_dict)\n",
        "  the_time = time.time()\n",
        "  elapsed_since_last_log = the_time - last_log_time\n",
        "  if elapsed_since_last_log \u003e log_every_seconds:\n",
        "    last_log_time = the_time\n",
        "    feed_dict, raw_graphs = create_feed_dict(\n",
        "        rand, batch_size_ge, num_nodes_min_max_ge, theta, input_ph, target_ph)\n",
        "    test_values = sess.run({\n",
        "        \"target\": target_ph,\n",
        "        \"loss\": loss_op_ge,\n",
        "        \"outputs\": output_ops_ge\n",
        "    },\n",
        "                           feed_dict=feed_dict)\n",
        "    correct_tr, solved_tr = compute_accuracy(\n",
        "        train_values[\"target\"], train_values[\"outputs\"][-1], use_edges=True)\n",
        "    correct_ge, solved_ge = compute_accuracy(\n",
        "        test_values[\"target\"], test_values[\"outputs\"][-1], use_edges=True)\n",
        "    elapsed = time.time() - start_time\n",
        "    losses_tr.append(train_values[\"loss\"])\n",
        "    corrects_tr.append(correct_tr)\n",
        "    solveds_tr.append(solved_tr)\n",
        "    losses_ge.append(test_values[\"loss\"])\n",
        "    corrects_ge.append(correct_ge)\n",
        "    solveds_ge.append(solved_ge)\n",
        "    logged_iterations.append(iteration)\n",
        "    print(\"# {:05d}, T {:.1f}, Ltr {:.4f}, Lge {:.4f}, Ctr {:.4f}, Str\"\n",
        "          \" {:.4f}, Cge {:.4f}, Sge {:.4f}\".format(\n",
        "              iteration, elapsed, train_values[\"loss\"], test_values[\"loss\"],\n",
        "              correct_tr, solved_tr, correct_ge, solved_ge))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "cellView": "form",
        "colab": {},
        "colab_type": "code",
        "id": "u0ckrMtj72s-"
      },
      "outputs": [],
      "source": [
        "#@title Visualize results  { form-width: \"30%\" }\n",
        "\n",
        "# This cell visualizes the results of training. You can visualize the\n",
        "# intermediate results by interrupting execution of the cell above, and running\n",
        "# this cell. You can then resume training by simply executing the above cell\n",
        "# again.\n",
        "\n",
        "def softmax_prob_last_dim(x):  # pylint: disable=redefined-outer-name\n",
        "  e = np.exp(x)\n",
        "  return e[:, -1] / np.sum(e, axis=-1)\n",
        "\n",
        "\n",
        "# Plot results curves.\n",
        "fig = plt.figure(1, figsize=(18, 3))\n",
        "fig.clf()\n",
        "x = np.array(logged_iterations)\n",
        "# Loss.\n",
        "y_tr = losses_tr\n",
        "y_ge = losses_ge\n",
        "ax = fig.add_subplot(1, 3, 1)\n",
        "ax.plot(x, y_tr, \"k\", label=\"Training\")\n",
        "ax.plot(x, y_ge, \"k--\", label=\"Test/generalization\")\n",
        "ax.set_title(\"Loss across training\")\n",
        "ax.set_xlabel(\"Training iteration\")\n",
        "ax.set_ylabel(\"Loss (binary cross-entropy)\")\n",
        "ax.legend()\n",
        "# Correct.\n",
        "y_tr = corrects_tr\n",
        "y_ge = corrects_ge\n",
        "ax = fig.add_subplot(1, 3, 2)\n",
        "ax.plot(x, y_tr, \"k\", label=\"Training\")\n",
        "ax.plot(x, y_ge, \"k--\", label=\"Test/generalization\")\n",
        "ax.set_title(\"Fraction correct across training\")\n",
        "ax.set_xlabel(\"Training iteration\")\n",
        "ax.set_ylabel(\"Fraction nodes/edges correct\")\n",
        "# Solved.\n",
        "y_tr = solveds_tr\n",
        "y_ge = solveds_ge\n",
        "ax = fig.add_subplot(1, 3, 3)\n",
        "ax.plot(x, y_tr, \"k\", label=\"Training\")\n",
        "ax.plot(x, y_ge, \"k--\", label=\"Test/generalization\")\n",
        "ax.set_title(\"Fraction solved across training\")\n",
        "ax.set_xlabel(\"Training iteration\")\n",
        "ax.set_ylabel(\"Fraction examples solved\")\n",
        "\n",
        "# Plot graphs and results after each processing step.\n",
        "# The white node is the start, and the black is the end. Other nodes are colored\n",
        "# from red to purple to blue, where red means the model is confident the node is\n",
        "# off the shortest path, blue means the model is confident the node is on the\n",
        "# shortest path, and purplish colors mean the model isn't sure.\n",
        "max_graphs_to_plot = 6\n",
        "num_steps_to_plot = 4\n",
        "node_size = 120\n",
        "min_c = 0.3\n",
        "num_graphs = len(raw_graphs)\n",
        "targets = utils_np.graphs_tuple_to_data_dicts(test_values[\"target\"])\n",
        "step_indices = np.floor(\n",
        "    np.linspace(0, num_processing_steps_ge - 1,\n",
        "                num_steps_to_plot)).astype(int).tolist()\n",
        "outputs = list(\n",
        "    zip(*(utils_np.graphs_tuple_to_data_dicts(test_values[\"outputs\"][i])\n",
        "          for i in step_indices)))\n",
        "h = min(num_graphs, max_graphs_to_plot)\n",
        "w = num_steps_to_plot + 1\n",
        "fig = plt.figure(101, figsize=(18, h * 3))\n",
        "fig.clf()\n",
        "ncs = []\n",
        "for j, (graph, target, output) in enumerate(zip(raw_graphs, targets, outputs)):\n",
        "  if j \u003e= h:\n",
        "    break\n",
        "  pos = get_node_dict(graph, \"pos\")\n",
        "  ground_truth = target[\"nodes\"][:, -1]\n",
        "  # Ground truth.\n",
        "  iax = j * (1 + num_steps_to_plot) + 1\n",
        "  ax = fig.add_subplot(h, w, iax)\n",
        "  plotter = GraphPlotter(ax, graph, pos)\n",
        "  color = {}\n",
        "  for i, n in enumerate(plotter.nodes):\n",
        "    color[n] = np.array([1.0 - ground_truth[i], 0.0, ground_truth[i], 1.0\n",
        "                        ]) * (1.0 - min_c) + min_c\n",
        "  plotter.draw_graph_with_solution(node_size=node_size, node_color=color)\n",
        "  ax.set_axis_on()\n",
        "  ax.set_xticks([])\n",
        "  ax.set_yticks([])\n",
        "  try:\n",
        "    ax.set_facecolor([0.9] * 3 + [1.0])\n",
        "  except AttributeError:\n",
        "    ax.set_axis_bgcolor([0.9] * 3 + [1.0])\n",
        "  ax.grid(None)\n",
        "  ax.set_title(\"Ground truth\\nSolution length: {}\".format(\n",
        "      plotter.solution_length))\n",
        "  # Prediction.\n",
        "  for k, outp in enumerate(output):\n",
        "    iax = j * (1 + num_steps_to_plot) + 2 + k\n",
        "    ax = fig.add_subplot(h, w, iax)\n",
        "    plotter = GraphPlotter(ax, graph, pos)\n",
        "    color = {}\n",
        "    prob = softmax_prob_last_dim(outp[\"nodes\"])\n",
        "    for i, n in enumerate(plotter.nodes):\n",
        "      color[n] = np.array([1.0 - prob[n], 0.0, prob[n], 1.0\n",
        "                          ]) * (1.0 - min_c) + min_c\n",
        "    plotter.draw_graph_with_solution(node_size=node_size, node_color=color)\n",
        "    ax.set_title(\"Model-predicted\\nStep {:02d} / {:02d}\".format(\n",
        "        step_indices[k] + 1, step_indices[-1] + 1))"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "last_runtime": {
        "build_target": "//learning/deepmind/dm_python:dm_notebook3",
        "kind": "private"
      },
      "name": "shortest_path.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
